from datasets import load_dataset
from distilabel.models import vLLM
from distilabel.pipeline import Pipeline
from distilabel.steps.tasks import TextGeneration
from distilabel.steps import (KeepColumns, FormatTextGenerationSFT)
import shutil
import os
import pandas as pd
#import math
#from tqdm import tqdm



# Delete the directory at `pipeline_cache`, if it exists, before the pipeline
# further down is defined. The last path component is the same string that is
# passed as `name=` to `Pipeline(...)`.
# NOTE(review): presumably this is distilabel's per-pipeline cache and the
# intent is to force a fresh run rather than reuse cached results — confirm.
# NOTE(review): the path is hard-coded under /root; when the cache lives
# anywhere else the check below is simply false and nothing is removed.
pipeline_cache = '/root/.cache/distilabel/pipelines/distill-qwen-32b-r1-tombench'
if os.path.exists(pipeline_cache):
    shutil.rmtree(pipeline_cache)


# Prompt text handed to `TextGeneration(template=...)` further down. The
# `{{ ... }}` placeholders (STORY, QUESTION, containers) use the same names that
# are listed in that step's `columns=` argument; `containers` is the choices
# string produced by `add_combined_column`.
# The literal is not a raw string, so `\\boxed{}` is stored as `\boxed{}`, and
# the backslash right after the opening quotes suppresses the leading newline.
prompt_template = """\
You will be given a story, a question, and the corresponding choices. Please reason step by step to answer this question, and put your final answer within \\boxed{}:

Story: {{ STORY }}
Question: {{QUESTION}}
Choices: {{containers}}
"""

# Load the local JSON file as the "train" split, sliced to its first 2420 rows.
dataset = load_dataset(
    "json",
    data_files=".../ToMbench_data/train_combined.json",
    split="train[:2420]",
)



def add_combined_column(dataset):
    """Add the ``containers`` and ``entire_instruction`` columns to *dataset*.

    For every example the answer options are flattened into one string of the
    form ``"A: ... B: ... C: ... D: ..."`` (stored under ``containers``), and a
    full instruction ``"Story: ... Question: ... Choices: ..."`` is stored
    under ``entire_instruction``.

    Args:
        dataset: Any object exposing ``map(fn)``. Each example must provide the
            keys ``STORY``, ``QUESTION``, ``OPTION-A``, ``OPTION-B``,
            ``OPTION-C`` and ``OPTION-D``. Options C and D may be ``None``, in
            which case they are left out of the choices string.

    Returns:
        Whatever ``dataset.map`` returns for the per-example transformation.

    Raises:
        KeyError: If an example lacks one of the required keys.
        TypeError: If ``OPTION-A`` or ``OPTION-B`` is not a string.
    """
    # Options that may legitimately be absent (questions with 2 or 3 choices).
    optional_options = (("C", "OPTION-C"), ("D", "OPTION-D"))

    def combine_text(example):
        # A and B are mandatory. Plain concatenation (rather than an f-string)
        # is deliberate: a missing value raises instead of rendering "None".
        parts = ["A: " + example["OPTION-A"], "B: " + example["OPTION-B"]]

        for label, key in optional_options:
            option = example[key]
            # Identity check: only a true None means "option not present".
            if option is not None:
                parts.append(label + ": " + option)

        choices = " ".join(parts)
        example["containers"] = choices
        example["entire_instruction"] = (
            f"Story: {example['STORY']} Question: {example['QUESTION']} Choices: {choices}"
        )
        return example

    return dataset.map(combine_text)


# Attach the `containers` / `entire_instruction` columns, then print the
# dataset summary followed by its first row as a quick sanity check.
dataset = add_combined_column(dataset)
print(dataset)
first_example = dataset[0]
print(first_example)



# Model location; used below for both the vLLM `model` and `tokenizer`.
model_id = ".../distill32B"

# Two-step pipeline: generate answers with the model, then reshape each
# generation into SFT format. Both steps are created inside the `with` block.
with Pipeline(
    name="distill-qwen-32b-r1-tombench",
    description="A pipeline to generate data from a distilled r1 model",
) as pipeline:

    # NOTE(review): `max_new_tokens` equals `max_model_len`, so the prompt and a
    # full-length completion cannot both fit in the context window — confirm
    # that hitting the length limit mid-generation is acceptable.
    llm = vLLM(
        model=model_id,
        tokenizer=model_id,
        extra_kwargs={
            "tensor_parallel_size": 1,
            "max_model_len": 16384,
        },
        generation_kwargs={
            "temperature": 0.7,
            "max_new_tokens": 16384,
        },
    )


   
    # Builds each prompt from `prompt_template` using the three listed columns;
    # requests 2 generations per input, with inputs batched 4 at a time.
    text_generation = TextGeneration(
        llm=llm, 
        template=prompt_template,
        num_generations=2,
        input_batch_size=4,
        columns = ["STORY", "QUESTION", "containers"],
    )

    
    
    # The formatter's `instruction` input is read from the dataset column
    # `entire_instruction`.
    # NOTE(review): `entire_instruction` does not contain the preamble sentence
    # from `prompt_template`, so the SFT instruction differs from the prompt the
    # model actually saw — confirm this is intended.
    format_sft = FormatTextGenerationSFT(input_mappings={"instruction": "entire_instruction"})
   

    # Feed the generation step's output into the SFT formatter.
    text_generation.connect(format_sft)
    



if __name__ == "__main__":
    # Single source of truth for where the generated distiset is persisted.
    output_dir = ".../SFTData/ToMBench_May_test"

    # Run the pipeline on the prepared dataset, then show the resulting
    # distiset and the first row of its default/train split.
    distiset = pipeline.run(dataset=dataset)
    print(distiset)
    print(distiset['default']['train'][0])

    distiset.save_to_disk(output_dir)

    # `load_from_disk` returns the loaded distiset instead of updating the
    # object it is called on. The original discarded that return value and
    # printed the in-memory object again; bind the result so the final print
    # actually verifies the save/load round trip.
    reloaded = distiset.load_from_disk(output_dir)
    print(reloaded)


   